# Simulate coalescent process with heterochronous sampling times

# Assumptions and modifications
# - truncates to a time period of interest
# - runs across a range of sample numbers (data)
# - deposits batch runs in a single folder
# - simulates a single N(t) trajectory
# - samples placed uniformly across time

# Start from a clean session: close open connections, drop every object in
# the current environment, send a form feed to the console (clears it in
# consoles that honour it) and shut all graphics devices
closeAllConnections()
rm(list = ls())
cat("\014")
graphics.off()

# Packages for phylodyn
library("sp")
library("devtools")
library("INLA")
library("spam")
library("ape")
library("phylodyn")

# Set working directory to the folder holding this script.
# When the file is run through source(), the script path is read from `ofile`
# in the frame reached by parent.frame(2). Outside source() that lookup yields
# NULL, so fall back to Rscript's --file= argument and, failing that, to the
# current working directory.
scriptPath <- parent.frame(2)$ofile
if (!is.character(scriptPath)) {
  scriptPath <- NULL
  fileArg <- grep("^--file=", commandArgs(trailingOnly = FALSE), value = TRUE)
  if (length(fileArg) > 0) {
    # Rscript encodes spaces in the script path as "~+~"
    scriptPath <- gsub("~+~", " ", sub("^--file=", "", fileArg[1]), fixed = TRUE)
  }
}
this.dir <- if (is.null(scriptPath)) getwd() else dirname(scriptPath)
# Make the path absolute before changing directory: a relative this.dir would
# point somewhere else once setwd() has moved into it, breaking the output
# paths that are later built from this.dir
this.dir <- normalizePath(this.dir, winslash = "/", mustWork = TRUE)
setwd(this.dir)

# Write a value as a comma-separated file without row or column names
#   val      - data handed straight to write.table
#   name     - file name, e.g. 'nc.csv'
#   pathname - folder prefix; it is glued to name as-is, so it must already
#              end in a path separator
tableWrite <- function(val, name, pathname) {
  target <- paste0(c(pathname, name), collapse = "")
  write.table(val, file = target, sep = ",", col.names = FALSE, row.names = FALSE)
}

# Define a middling bottleneck
# Piecewise-constant population size: 20 strictly between t = 15 and t = 40,
# 200 at and outside those two changepoints. Returns one value per entry of t.
bottle_traj <- function(t) {
  squeezed <- t > 15 & t < 40
  size <- numeric(length(t))
  size[!squeezed] <- 200
  size[squeezed] <- 20
  size
}

# Define a boom-bust with a later changepoint and an offset
# Population size grows exponentially up to the changepoint `bust`, where it
# peaks at scale + offset, and decays exponentially afterwards; `offset` is the
# floor the curve approaches far from the changepoint.
#   t      - times at which to evaluate the trajectory
#   bust   - time of the peak
#   scale  - height of the peak above the floor
#   offset - floor added to the whole curve
boom_traj <- function(t, bust = 20, scale = 1000, offset = 100) {
  growing <- t <= bust
  shrinking <- t > bust
  size <- numeric(length(t))
  size[growing] <- offset + scale * exp(t[growing] - bust)
  size[shrinking] <- offset + scale * exp(bust - t[shrinking])
  size
}

# Define a logistic trajectory with larger N
# N is the height of the logistic rise and N0 (1% of N) the floor added to it
N <- 500
N0 <- 0.01 * N
# Periodic trajectory with period 12: within each period the size rises
# logistically over the first half (midpoint at phase 3) and falls back over
# the second half (midpoint at phase 9).
#   t      - times at which to evaluate the trajectory
#   offset - shift added to t before the phase is taken
#   a      - steepness of the logistic transitions
logis_traj <- function(t, offset = 0, a = 2) {
  phase <- (t + offset) %% 12
  rising <- phase <= 6
  falling <- phase > 6
  size <- numeric(length(phase))
  size[rising] <- N0 + N / (1 + exp((3 - phase[rising]) * a))
  size[falling] <- N0 + N / (1 + exp((phase[falling] - 12 + 3) * a))
  size
}

# Main code for heterochronous simulations ----------------------------------------------------------
# For each total sample count in nSamps: spread the samples over the sampling
# times, simulate one genealogy under the chosen trajectory, truncate it at
# truncTime and export the results as csv files into a trajectory-specific
# folder under this.dir.

# Choose trajectory case (an index into trajNames)
trajCase <- 3
trajNames <- c('cyclicSamps', 'bottleSamps', 'boomSamps', 'steepSamps', 'logisSamps')
# Fail early on an invalid case: switch() would otherwise silently yield NULL
stopifnot(length(trajCase) == 1, trajCase %in% seq_along(trajNames))

# Choose trajectory type. trajCase is numeric, so switch() selects the
# alternative by position and evaluates only that one; the "1".."5" labels are
# for the reader. cyclic_traj and steep_cyc_traj are not defined in this
# script (presumably supplied by phylodyn - confirm).
trajType <- switch(trajCase,
                   "1" = cyclic_traj,
                   "2" = bottle_traj,
                   "3" = boom_traj,
                   "4" = steep_cyc_traj,
                   "5" = logis_traj
)
traj <- trajType
trajVal <- trajNames[trajCase]

# Lower bound handed to the thinning sampler in coalsim; it must not exceed the
# smallest value of traj. logis_traj dips to just above N0 (which is below 10),
# so that case uses N0; every other case keeps the original value of 10.
lowerBound <- if (trajCase == 5) N0 else 10

# Range of sample numbers to loop across
nSamps <- seq(401, 801, 100)
if (trajCase == 4) {
  nSamps <- seq(801, 2001, 100)
}
numRuns <- length(nSamps)

# Uniform sampling across time
all_samp_end <- 40 # often set to 60 for cyclic and bottle
ndivs <- 20
# Sample times
samp_times <- seq(0, all_samp_end, length.out = ndivs)

# Period of truncation and extra initial samples
truncTime <- 85
tsamp0 <- 20
if (trajCase == 2) {
  tsamp1 <- 20
  tsamp2 <- 20
} else {
  tsamp1 <- 0
  tsamp2 <- 0
}

# Positions of the sampling times at or after t = 40 (the time at which
# bottle_traj returns to 200); tsamp1/tsamp2 are added just before and at the
# first of them. Stop if those extra samples would have nowhere to go, since
# they are already subtracted from the uniformly spread total below.
id1 <- which(samp_times >= 40)
if ((tsamp1 > 0 || tsamp2 > 0) &&
    (length(id1) == 0 || (tsamp1 > 0 && id1[1] < 2))) {
  stop("extra samples (tsamp1/tsamp2) need a sampling time at or after 40 ",
       "with one sampling time before it", call. = FALSE)
}

# Create folder for traj specific results (skipped when it already exists so
# reruns do not raise a warning)
trajName <- paste(c(trajVal, '_', numRuns), collapse = '')
if (!dir.exists(file.path(this.dir, trajName))) {
  dir.create(file.path(this.dir, trajName))
}
pathf <- paste(c(this.dir, '/', trajName, '/'), collapse = "")

# Coalescent events and max time for each trajectory
nc <- rep(0, numRuns)
tmax <- rep(0, numRuns)

for (i in seq_len(numRuns)) {
  # Samples left to spread uniformly once the extra ones are set aside
  nsamps <- nSamps[i] - tsamp0 - tsamp1 - tsamp2

  # Equal share per sampling time; the last time takes the remainder
  samps <- c(rep(floor(nsamps/ndivs), ndivs-1), nsamps-(ndivs-1)*floor(nsamps/ndivs))
  # Extra samples at beginning (and near end)
  samps[1] <- samps[1] + tsamp0
  if (length(id1) > 0) {
    samps[id1[1]-1] <- samps[id1[1]-1] + tsamp1
    samps[id1[1]] <- samps[id1[1]] + tsamp2
  }

  # Simulate genealogy and get all times
  gene <- coalsim(samp_times = samp_times, n_sampled = samps, traj = traj,
                  lower_bound = lowerBound, method = "thin")
  coal_times <- gene$coal_times
  coalLin <- gene$lineages

  # Truncate trees to truncTime
  idtrunc <- coal_times <= truncTime
  coal_times <- coal_times[idtrunc]
  coalLin <- coalLin[idtrunc]

  # Last retained coalescent time (the TMRCA unless the tree was truncated)
  # and no. coalescent events
  tmax[i] <- max(coal_times)
  nc[i] <- length(coal_times)

  # Export trajectory specific data for Matlab
  tableWrite(coal_times, paste(c('coaltimes', i, '.csv'), collapse = ''), pathf)
  tableWrite(coalLin, paste(c('coalLin', i, '.csv'), collapse = ''), pathf)
  tableWrite(samps, paste(c('sampIntro', i, '.csv'), collapse = ''), pathf)
}

# No. coalescences, samples and TMRCA
tableWrite(nc, 'nc.csv', pathf)
tableWrite(nSamps, 'ns.csv', pathf)
tableWrite(tmax, 'tmax.csv', pathf)
tableWrite(truncTime, 'truncTime.csv', pathf)
tableWrite(samp_times, 'samptimes.csv', pathf)

# True population size on a fine grid up to the largest retained time
t <- seq(0, max(tmax), length.out = 20000)
y <- traj(t)
tableWrite(t, 'trajt.csv', pathf)
tableWrite(y, 'trajy.csv', pathf)